{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative Adversarial Network\n",
    "\n",
    "In this notebook, we'll be building a generative adversarial network (GAN) trained on the MNIST dataset. From this, we'll be able to generate new handwritten digits!\n",
    "\n",
    "GANs were [first reported on](https://arxiv.org/abs/1406.2661) in 2014 from Ian Goodfellow and others in Yoshua Bengio's lab. Since then, GANs have exploded in popularity. Here are a few examples to check out:\n",
    "\n",
    "* [Pix2Pix](https://affinelayer.com/pixsrv/) \n",
    "* [CycleGAN & Pix2Pix in PyTorch, Jun-Yan Zhu](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)\n",
    "* [A list of generative models](https://github.com/wiseodd/generative-models)\n",
    "\n",
    "The idea behind GANs is that you have two networks, a generator $G$ and a discriminator $D$, competing against each other. The generator makes \"fake\" data to pass to the discriminator. The discriminator also sees real training data and predicts if the data it's received is real or fake. \n",
 * The generator">
    "> * The generator is trained to fool the discriminator; it wants to output data that looks _as close as possible_ to real training data. \n",
    "* The discriminator is a classifier that is trained to figure out which data is real and which is fake. \n",
    "\n",
    "What ends up happening is that the generator learns to make data that is indistinguishable from real data to the discriminator.\n",
    "\n",
    "<img src='assets/gan_pipeline.png' width=70% />\n",
    "\n",
    "The general structure of a GAN is shown in the diagram above, using MNIST images as data. The latent sample is a random vector that the generator uses to construct its fake images. This is often called a **latent vector** and that vector space is called **latent space**. As the generator trains, it figures out how to map latent vectors to recognizable images that can fool the discriminator.\n",
    "\n",
    "If you're interested in generating only new images, you can throw out the discriminator after training. In this notebook, I'll show you how to define and train these adversarial networks in PyTorch and generate new images!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# number of subprocesses to use for data loading\n",
    "num_workers = 0\n",
    "# how many samples per batch to load\n",
    "batch_size = 64\n",
    "\n",
    "# convert data to torch.FloatTensor\n",
    "transform = transforms.ToTensor()\n",
    "\n",
    "# get the training datasets\n",
    "train_data = datasets.MNIST(root='data', train=True,\n",
    "                                   download=True, transform=transform)\n",
    "\n",
    "# prepare data loader\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,\n",
    "                                           num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# obtain one batch of training images\n",
    "dataiter = iter(train_loader)\n",
    "images, labels = dataiter.next()\n",
    "images = images.numpy()\n",
    "\n",
    "# get one image from the batch\n",
    "img = np.squeeze(images[0])\n",
    "\n",
    "fig = plt.figure(figsize = (3,3)) \n",
    "ax = fig.add_subplot(111)\n",
    "ax.imshow(img, cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Define the Model\n",
    "\n",
    "A GAN is comprised of two adversarial networks, a discriminator and a generator."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Discriminator\n",
    "\n",
    "The discriminator network is going to be a pretty typical linear classifier. To make this network a universal function approximator, we'll need at least one hidden layer, and these hidden layers should have one key attribute:\n",
    "> All hidden layers will have a [Leaky ReLu](https://pytorch.org/docs/stable/nn.html#torch.nn.LeakyReLU) activation function applied to their outputs.\n",
    "\n",
    "<img src='assets/gan_network.png' width=70% />\n",
    "\n",
    "#### Leaky ReLu\n",
    "\n",
    "We should use a leaky ReLU to allow gradients to flow backwards through the layer unimpeded. A leaky ReLU is like a normal ReLU, except that there is a small non-zero output for negative input values.\n",
    "\n",
    "<img src='assets/leaky_relu.png' width=40% />\n",
    "\n",
    "#### Sigmoid Output\n",
    "\n",
    "We'll also take the approach of using a more numerically stable loss function on the outputs. Recall that we want the discriminator to output a value 0-1 indicating whether an image is _real or fake_. \n",
    "> We will ultimately use [BCEWithLogitsLoss](https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss), which combines a `sigmoid` activation function **and** binary cross entropy loss in one function. \n",
    "\n",
    "So, our final output layer should not have any activation function applied to it."
   ]
  },
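  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the output layer can stay linear, here is a small check (a sketch on made-up logits, not part of the model) that `BCEWithLogitsLoss` on raw logits matches applying a sigmoid followed by `BCELoss`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# made-up discriminator outputs (logits) and target labels, for illustration only\n",
    "logits = torch.tensor([2.0, -1.0, 0.5])\n",
    "labels = torch.tensor([1.0, 0.0, 1.0])\n",
    "\n",
    "# sigmoid + binary cross entropy in one call, applied to raw logits\n",
    "loss_with_logits = nn.BCEWithLogitsLoss()(logits, labels)\n",
    "# the same quantity computed in two separate steps\n",
    "loss_two_step = nn.BCELoss()(torch.sigmoid(logits), labels)\n",
    "\n",
    "print(loss_with_logits.item(), loss_two_step.item())\n",
    "```\n",
    "\n",
    "The two values agree up to floating point error; the combined version is generally preferred because it handles the sigmoid and the logarithm together in a numerically safer way."
   ]
  },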
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "\n",
    "    def __init__(self, input_size, hidden_dim, output_size):\n",
    "        super(Discriminator, self).__init__()\n",
    "        \n",
    "        # define all layers\n",
    "        \n",
    "        \n",
    "    def forward(self, x):\n",
    "        # flatten image\n",
    "        \n",
    "        # pass x through all layers\n",
    "        # apply leaky relu activation to all hidden layers\n",
    "\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generator\n",
    "\n",
    "The generator network will be almost exactly the same as the discriminator network, except that we're applying a [tanh activation function](https://pytorch.org/docs/stable/nn.html#tanh) to our output layer.\n",
    "\n",
    "#### tanh Output\n",
    "The generator has been found to perform the best with $tanh$ for the generator output, which scales the output to be between -1 and 1, instead of 0 and 1. \n",
    "\n",
    "<img src='assets/tanh_fn.png' width=40% />\n",
    "\n",
    "Recall that we also want these outputs to be comparable to the *real* input pixel values, which are read in as normalized values between 0 and 1. \n",
    "> So, we'll also have to **scale our real input images to have pixel values between -1 and 1** when we train the discriminator. \n",
    "\n",
    "I'll do this in the training loop, later on."
   ]
  },
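  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sketch of the two ranges involved, using random stand-in tensors rather than real MNIST batches or a trained generator (the shapes and names below are illustrative assumptions):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# stand-in for a batch of real images as loaded by ToTensor, values in [0, 1]\n",
    "real_images = torch.rand(4, 1, 28, 28)\n",
    "\n",
    "# rescale to [-1, 1] so real images match the range of a tanh output\n",
    "scaled = real_images * 2 - 1\n",
    "\n",
    "# made-up pre-activation values passed through tanh, as a generator's last layer would\n",
    "fake_images = torch.tanh(torch.randn(4, 784))\n",
    "\n",
    "print(scaled.min().item(), scaled.max().item())\n",
    "print(fake_images.min().item(), fake_images.max().item())\n",
    "```\n",
    "\n",
    "After rescaling, real and generated pixel values live in the same interval, so the discriminator cannot separate them by value range alone."
   ]
  },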
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "\n",
    "    def __init__(self, input_size, hidden_dim, output_size):\n",
    "        super(Generator, self).__init__()\n",
    "        \n",
    "        # define all layers\n",
    "        \n",
    "\n",
    "    def forward(self, x):\n",
    "        # pass x through all layers\n",
    "        \n",
    "        # final layer should have tanh applied\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discriminator hyperparams\n",
    "\n",
    "# Size of input image to discriminator (28*28)\n",
    "input_size = \n",
    "# Size of discriminator output (real or fake)\n",
    "d_output_size = \n",
    "# Size of *last* hidden layer in the discriminator\n",
    "d_hidden_size = \n",
    "\n",
    "# Generator hyperparams\n",
    "\n",
    "# Size of latent vector to give to generator\n",
    "z_size = \n",
    "# Size of generator output (generated image)\n",
    "g_output_size = \n",
    "# Size of *first* hidden layer in the generator\n",
    "g_hidden_size = "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build complete network\n",
    "\n",
    "Now we're instantiating the discriminator and generator from the classes defined above. Make sure you've passed in the correct input arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# instantiate discriminator and generator\n",
    "D = Discriminator(input_size, d_hidden_size, d_output_size)\n",
    "G = Generator(z_size, g_hidden_size, g_output_size)\n",
    "\n",
    "# check that they are as you expect\n",
    "print(D)\n",
    "print()\n",
    "print(G)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Discriminator and Generator Losses\n",
    "\n",
    "Now we need to calculate the losses. \n",
    "\n",
    "### Discriminator Losses\n",
    "\n",
    "> * For the discriminator, the total loss is the sum of the losses for real and fake images, `d_loss = d_real_loss + d_fake_loss`. \n",
    "* Remember that we want the discriminator to output 1 for real images and 0 for fake images, so we need to set up the losses to reflect that.\n",
    "\n",
    "<img src='assets/gan_pipeline.png' width=70% />\n",
    "\n",
    "The losses will by binary cross entropy loss with logits, which we can get with [BCEWithLogitsLoss](https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss). This combines a `sigmoid` activation function **and** and binary cross entropy loss in one function.\n",
    "\n",
    "For the real images, we want `D(real_images) = 1`. That is, we want the discriminator to classify the the real images with a label = 1, indicating that these are real. To help the discriminator generalize better, the labels are **reduced a bit from 1.0 to 0.9**. For this, we'll use the parameter `smooth`; if True, then we should smooth our labels. In PyTorch, this looks like `labels = torch.ones(size) * 0.9`\n",
    "\n",
    "The discriminator loss for the fake data is similar. We want `D(fake_images) = 0`, where the fake images are the _generator output_, `fake_images = G(z)`. \n",
    "\n",
    "### Generator Loss\n",
    "\n",
    "The generator loss will look similar only with flipped labels. The generator's goal is to get `D(fake_images) = 1`. In this case, the labels are **flipped** to represent that the generator is trying to fool the discriminator into thinking that the images it generates (fakes) are real!"
   ]
  },
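  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged illustration of the label setup described above, the sketch below scores the same made-up logits against smoothed real labels and against fake labels. The tensor `d_out` and the other names are placeholders for illustration, not outputs of the networks in this notebook:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "\n",
    "# made-up discriminator logits for a batch of 4 images\n",
    "d_out = torch.tensor([[1.5], [0.3], [-0.7], [2.0]])\n",
    "batch_size = d_out.size(0)\n",
    "\n",
    "# real labels, smoothed from 1.0 down to 0.9\n",
    "smooth_real_labels = torch.ones(batch_size) * 0.9\n",
    "# fake labels are all zeros\n",
    "fake_labels = torch.zeros(batch_size)\n",
    "\n",
    "# squeeze removes the trailing dimension so logits and labels have the same shape\n",
    "loss_if_real = criterion(d_out.squeeze(), smooth_real_labels)\n",
    "loss_if_fake = criterion(d_out.squeeze(), fake_labels)\n",
    "\n",
    "print(loss_if_real.item(), loss_if_fake.item())\n",
    "```\n",
    "\n",
    "For the generator loss, the same criterion would be applied to `D(fake_images)` with labels of all ones, which is what flipping the labels means here."
   ]
  },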
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate losses\n",
    "def real_loss(D_out, smooth=False):\n",
    "    # compare logits to real labels\n",
    "    # smooth labels if smooth=True\n",
    "    loss = \n",
    "    return loss\n",
    "\n",
    "def fake_loss(D_out):\n",
    "    # compare logits to fake labels\n",
    "    loss = \n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizers\n",
    "\n",
    "We want to update the generator and discriminator variables separately. So, we'll define two separate Adam optimizers."
   ]
  },
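  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way this separation might look, sketched on tiny stand-in networks (`toy_D`, `toy_G` and their layer sizes are hypothetical and only serve the illustration):\n",
    "\n",
    "```python\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "# tiny stand-in networks; the layer sizes are arbitrary\n",
    "toy_D = nn.Linear(784, 1)\n",
    "toy_G = nn.Linear(100, 784)\n",
    "\n",
    "# one optimizer per network, each seeing only that network's parameters\n",
    "toy_d_optimizer = optim.Adam(toy_D.parameters(), lr=0.002)\n",
    "toy_g_optimizer = optim.Adam(toy_G.parameters(), lr=0.002)\n",
    "```\n",
    "\n",
    "Since each optimizer only holds one network's parameters, calling `step()` on one of them never changes the other network's weights."
   ]
  },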
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "# learning rate for optimizers\n",
    "lr = 0.002\n",
    "\n",
    "# Create optimizers for the discriminator and generator\n",
    "d_optimizer = \n",
    "g_optimizer = "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Training\n",
    "\n",
    "Training will involve alternating between training the discriminator and the generator. We'll use our functions `real_loss` and `fake_loss` to help us calculate the discriminator losses in all of the following cases.\n",
    "\n",
    "### Discriminator training\n",
    "1. Compute the discriminator loss on real, training images        \n",
    "2. Generate fake images\n",
    "3. Compute the discriminator loss on fake, generated images     \n",
    "4. Add up real and fake loss\n",
    "5. Perform backpropagation + an optimization step to update the discriminator's weights\n",
    "\n",
    "### Generator training\n",
    "1. Generate fake images\n",
    "2. Compute the discriminator loss on fake images, using **flipped** labels!\n",
    "3. Perform backpropagation + an optimization step to update the generator's weights\n",
    "\n",
    "#### Saving Samples\n",
    "\n",
    "As we train, we'll also print out some loss statistics and save some generated \"fake\" samples."
   ]
  },
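  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps above can be sketched as a single alternating update on toy data. Everything here is a stand-in chosen for illustration (`toy_D`, `toy_G`, the layer sizes, the random data); detaching the fake images in the discriminator step is an optional design choice of this sketch, not something the notebook requires:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# tiny stand-in networks; all sizes here are arbitrary\n",
    "toy_D = nn.Sequential(nn.Linear(8, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))\n",
    "toy_G = nn.Sequential(nn.Linear(4, 16), nn.LeakyReLU(0.2), nn.Linear(16, 8), nn.Tanh())\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "d_opt = optim.Adam(toy_D.parameters(), lr=0.002)\n",
    "g_opt = optim.Adam(toy_G.parameters(), lr=0.002)\n",
    "\n",
    "real = torch.rand(32, 8) * 2 - 1  # stand-in for rescaled real images\n",
    "\n",
    "# ---- discriminator step ----\n",
    "d_opt.zero_grad()\n",
    "# loss on real data, with smoothed labels\n",
    "d_real_loss = criterion(toy_D(real).squeeze(), torch.ones(32) * 0.9)\n",
    "# generate fake data from latent vectors sampled uniformly from [-1, 1)\n",
    "z = torch.rand(32, 4) * 2 - 1\n",
    "fake = toy_G(z)\n",
    "# detach so this step does not backpropagate into the generator\n",
    "d_fake_loss = criterion(toy_D(fake.detach()).squeeze(), torch.zeros(32))\n",
    "d_loss = d_real_loss + d_fake_loss\n",
    "d_loss.backward()\n",
    "d_opt.step()\n",
    "\n",
    "# ---- generator step ----\n",
    "g_opt.zero_grad()\n",
    "z = torch.rand(32, 4) * 2 - 1\n",
    "fake = toy_G(z)\n",
    "# flipped labels: the generator wants the discriminator to say real\n",
    "g_loss = criterion(toy_D(fake).squeeze(), torch.ones(32))\n",
    "g_loss.backward()\n",
    "g_opt.step()\n",
    "\n",
    "print(d_loss.item(), g_loss.item())\n",
    "```\n",
    "\n",
    "In the full training loop, this pair of updates runs once per batch, with real MNIST images in place of the random `real` tensor."
   ]
  },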
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "\n",
    "# training hyperparams\n",
    "num_epochs = 40\n",
    "\n",
    "# keep track of loss and generated, \"fake\" samples\n",
    "samples = []\n",
    "losses = []\n",
    "\n",
    "print_every = 400\n",
    "\n",
    "# Get some fixed data for sampling. These are images that are held\n",
    "# constant throughout training, and allow us to inspect the model's performance\n",
    "sample_size=16\n",
    "fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))\n",
    "fixed_z = torch.from_numpy(fixed_z).float()\n",
    "\n",
    "# train the network\n",
    "D.train()\n",
    "G.train()\n",
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    for batch_i, (real_images, _) in enumerate(train_loader):\n",
    "                \n",
    "        batch_size = real_images.size(0)\n",
    "        \n",
    "        ## Important rescaling step ## \n",
    "        real_images = real_images*2 - 1  # rescale input images from [0,1) to [-1, 1)\n",
    "        \n",
    "        # ============================================\n",
    "        #            TRAIN THE DISCRIMINATOR\n",
    "        # ============================================\n",
    "                \n",
    "        # 1. Train with real images\n",
    "\n",
    "        # Compute the discriminator losses on real images\n",
    "        # use smoothed labels\n",
    "        \n",
    "        # 2. Train with fake images\n",
    "        \n",
    "        # Generate fake images\n",
    "        z = np.random.uniform(-1, 1, size=(batch_size, z_size))\n",
    "        z = torch.from_numpy(z).float()\n",
    "        fake_images = G(z)\n",
    "        \n",
    "        # Compute the discriminator losses on fake images        \n",
    "        \n",
    "        # add up real and fake losses and perform backprop\n",
    "        d_loss = \n",
    "        \n",
    "        \n",
    "        # =========================================\n",
    "        #            TRAIN THE GENERATOR\n",
    "        # =========================================\n",
    "        \n",
    "        # 1. Train with fake images and flipped labels\n",
    "        \n",
    "        # Generate fake images\n",
    "        \n",
    "        # Compute the discriminator losses on fake images \n",
    "        # using flipped labels!\n",
    "        \n",
    "        # perform backprop\n",
    "        g_loss = \n",
    "        \n",
    "\n",
    "        # Print some loss stats\n",
    "        if batch_i % print_every == 0:\n",
    "            # print discriminator and generator loss\n",
    "            print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(\n",
    "                    epoch+1, num_epochs, d_loss.item(), g_loss.item()))\n",
    "\n",
    "    \n",
    "    ## AFTER EACH EPOCH##\n",
    "    # append discriminator loss and generator loss\n",
    "    losses.append((d_loss.item(), g_loss.item()))\n",
    "    \n",
    "    # generate and save sample, fake images\n",
    "    G.eval() # eval mode for generating samples\n",
    "    samples_z = G(fixed_z)\n",
    "    samples.append(samples_z)\n",
    "    G.train() # back to train mode\n",
    "\n",
    "\n",
    "# Save training generator samples\n",
    "with open('train_samples.pkl', 'wb') as f:\n",
    "    pkl.dump(samples, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loss\n",
    "\n",
    "Here we'll plot the training losses for the generator and discriminator, recorded after each epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "losses = np.array(losses)\n",
    "plt.plot(losses.T[0], label='Discriminator')\n",
    "plt.plot(losses.T[1], label='Generator')\n",
    "plt.title(\"Training Losses\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generator samples from training\n",
    "\n",
    "Here we can view samples of images from the generator. First we'll look at the images we saved during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper function for viewing a list of passed in sample images\n",
    "def view_samples(epoch, samples):\n",
    "    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)\n",
    "    for ax, img in zip(axes.flatten(), samples[epoch]):\n",
    "        img = img.detach()\n",
    "        ax.xaxis.set_visible(False)\n",
    "        ax.yaxis.set_visible(False)\n",
    "        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load samples from generator, taken while training\n",
    "with open('train_samples.pkl', 'rb') as f:\n",
    "    samples = pkl.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These are samples from the final training epoch. You can see the generator is able to reproduce numbers like 1, 7, 3, 2. Since this is just a sample, it isn't representative of the full range of images this generator can make."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -1 indicates final epoch's samples (the last in the list)\n",
    "view_samples(-1, samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below I'm showing the generated images as the network was training, every 10 epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rows = 10 # split epochs into 10, so 100/10 = every 10 epochs\n",
    "cols = 6\n",
    "fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)\n",
    "\n",
    "for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):\n",
    "    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):\n",
    "        img = img.detach()\n",
    "        ax.imshow(img.reshape((28,28)), cmap='Greys_r')\n",
    "        ax.xaxis.set_visible(False)\n",
    "        ax.yaxis.set_visible(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It starts out as all noise. Then it learns to make only the center white and the rest black. You can start to see some number like structures appear out of the noise like 1s and 9s."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling from the generator\n",
    "\n",
    "We can also get completely new images from the generator by using the checkpoint we saved after training. **We just need to pass in a new latent vector $z$ and we'll get new samples**!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# randomly generated, new latent vectors\n",
    "sample_size=16\n",
    "rand_z = np.random.uniform(-1, 1, size=(sample_size, z_size))\n",
    "rand_z = torch.from_numpy(rand_z).float()\n",
    "\n",
    "G.eval() # eval mode\n",
    "# generated samples\n",
    "rand_images = G(rand_z)\n",
    "\n",
    "# 0 indicates the first set of samples in the passed in list\n",
    "# and we only have one batch of samples, here\n",
    "view_samples(0, [rand_images])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
